{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Custom API Endpoints\n",
        "\n",
        "The [txtai API](https://neuml.github.io/txtai/api/) is a web-based service backed by [FastAPI](https://fastapi.tiangolo.com/). Semantic search, LLM orchestration and Language Model Workflows can all run through the API.\n",
        "\n",
        "While the API is flexible and YAML-driven workflows can express complex logic, some developers may prefer to write an endpoint directly in Python.\n",
        "\n",
        "This notebook introduces API extensions and shows how they can be used to define custom Python endpoints that interact with txtai applications."
      ],
      "metadata": {
        "id": "VGeVB8M41jqW"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Install dependencies\n",
        "\n",
        "Install `txtai` and all dependencies."
      ],
      "metadata": {
        "id": "ZQrHIw351lwE"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "R0AqRP7v1hdr"
      },
      "outputs": [],
      "source": [
        "%%capture\n",
        "!pip install git+https://github.com/neuml/txtai#egg=txtai[api] datasets"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Define the extension\n",
        "\n",
        "First, we'll create an application that defines a persistent embeddings database and LLM. Then we'll combine those two into a RAG endpoint through the API."
      ],
      "metadata": {
        "id": "xmPN8RDF1pXd"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "%%writefile app.yml\n",
        "\n",
        "# Embeddings index\n",
        "writable: true\n",
        "embeddings:\n",
        "  hybrid: true\n",
        "  content: true\n",
        "\n",
        "# LLM pipeline\n",
        "llm:\n",
        "  path: google/flan-t5-large\n",
        "  torch_dtype: torch.bfloat16"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "XZ7vPBIs1rGZ",
        "outputId": "b5cf95f1-1a99-4839-ae9b-9141922bd248"
      },
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Writing app.yml\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "The code below creates an API endpoint at `/rag`. This is a `GET` endpoint that takes a `text` parameter as input."
      ],
      "metadata": {
        "id": "syd1PZ621sok"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "%%writefile rag.py\n",
        "from fastapi import APIRouter\n",
        "from txtai.api import application, Extension\n",
        "\n",
        "\n",
        "class RAG(Extension):\n",
        "    \"\"\"\n",
        "    API extension\n",
        "    \"\"\"\n",
        "\n",
        "    def __call__(self, app):\n",
        "        app.include_router(RAGRouter().router)\n",
        "\n",
        "\n",
        "class RAGRouter:\n",
        "    \"\"\"\n",
        "    API router\n",
        "    \"\"\"\n",
        "\n",
        "    router = APIRouter()\n",
        "\n",
        "    @staticmethod\n",
        "    @router.get(\"/rag\")\n",
        "    def rag(text: str):\n",
        "        \"\"\"\n",
        "        Runs a retrieval augmented generation (RAG) pipeline.\n",
        "\n",
        "        Args:\n",
        "            text: input text\n",
        "\n",
        "        Returns:\n",
        "            response\n",
        "        \"\"\"\n",
        "\n",
        "        # Run embeddings search\n",
        "        results = application.get().search(text, 3)\n",
        "        context = \" \".join([x[\"text\"] for x in results])\n",
        "\n",
        "        prompt = f\"\"\"\n",
        "        Answer the following question using only the context below.\n",
        "\n",
        "        Question: {text}\n",
        "        Context: {context}\n",
        "        \"\"\"\n",
        "\n",
        "        return {\n",
        "            \"response\": application.get().pipeline(\"llm\", (prompt,))\n",
        "        }"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "zXERt7Vw1ujq",
        "outputId": "2c680298-895b-419c-967d-70030265f5a6"
      },
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Writing rag.py\n"
          ]
        }
      ]
    },
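    {
      "cell_type": "markdown",
      "source": [
        "The prompt assembly inside the `/rag` handler can also be tried on its own, without a running API. The block below is a minimal sketch under stated assumptions: `TEMPLATE` and `build_prompt` are hypothetical names, the search results are made up, and the template is dedented first, which is a variation on the handler's indented f-string rather than the code used by the endpoint.\n",
        "\n",
        "```python\n",
        "from textwrap import dedent\n",
        "\n",
        "# Hypothetical template, dedented once before any values are substituted\n",
        "TEMPLATE = dedent('''\n",
        "    Answer the following question using only the context below.\n",
        "\n",
        "    Question: {text}\n",
        "    Context: {context}\n",
        "''').strip()\n",
        "\n",
        "\n",
        "def build_prompt(text, results):\n",
        "    # Join the text field of each search result into one context string\n",
        "    context = ' '.join(x['text'] for x in results)\n",
        "    return TEMPLATE.format(text=text, context=context)\n",
        "\n",
        "\n",
        "# Made-up search results, for illustration only\n",
        "results = [{'text': 'First story.'}, {'text': 'Second story.'}]\n",
        "print(build_prompt('What happened?', results))\n",
        "```\n",
        "\n",
        "Dedenting the template before substitution keeps source-code indentation out of the text sent to the LLM."
      ],
      "metadata": {}
    },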
    {
      "cell_type": "markdown",
      "source": [
        "# Start the API instance\n",
        "\n",
        "Let's start the API with the RAG extension."
      ],
      "metadata": {
        "id": "p7vl6_9i1w39"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!CONFIG=app.yml EXTENSIONS=rag.RAG nohup uvicorn \"txtai.api:app\" &> api.log &\n",
        "!sleep 60"
      ],
      "metadata": {
        "id": "FRif4lhW1y8m"
      },
      "execution_count": 4,
      "outputs": []
    },
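    {
      "cell_type": "markdown",
      "source": [
        "A fixed sleep works for a demo, but startup time depends on how long the model takes to load. One alternative is to poll until the server answers. The block below is a rough sketch of that idea, not part of txtai: `wait_until` is a hypothetical helper and the probe it calls is supplied by the caller.\n",
        "\n",
        "```python\n",
        "import time\n",
        "\n",
        "\n",
        "def wait_until(probe, attempts=60, delay=1.0):\n",
        "    # Hypothetical helper: call probe() until it returns True or attempts run out\n",
        "    for _ in range(attempts):\n",
        "        try:\n",
        "            if probe():\n",
        "                return True\n",
        "        except OSError:\n",
        "            # Connection errors are expected while the server is still starting\n",
        "            pass\n",
        "        time.sleep(delay)\n",
        "    return False\n",
        "```\n",
        "\n",
        "In this notebook, the probe could be a small function that sends a request to the running API and returns whether the response was successful. Connection errors raised by Requests derive from `OSError`, so they are treated as a signal to retry."
      ],
      "metadata": {}
    },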
    {
      "cell_type": "markdown",
      "source": [
        "# Create the embeddings database\n",
        "\n",
        "Next, we'll create the embeddings database using the `ag_news` dataset. This is a set of news stories from the mid-2000s."
      ],
      "metadata": {
        "id": "FTdkEDa0106G"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "import requests\n",
        "\n",
        "ds = load_dataset(\"ag_news\", split=\"train\")\n",
        "\n",
        "# API endpoint\n",
        "url = \"http://localhost:8000\"\n",
        "headers = {\"Content-Type\": \"application/json\"}\n",
        "\n",
        "# Add data in batches of 4096 rows\n",
        "batch = []\n",
        "for text in ds[\"text\"]:\n",
        "    batch.append({\"text\": text})\n",
        "    if len(batch) == 4096:\n",
        "        requests.post(f\"{url}/add\", headers=headers, json=batch, timeout=120).raise_for_status()\n",
        "        batch = []\n",
        "\n",
        "# Add the final partial batch\n",
        "if batch:\n",
        "    requests.post(f\"{url}/add\", headers=headers, json=batch, timeout=120).raise_for_status()\n",
        "\n",
        "# Build index, raising an error if the request fails\n",
        "index = requests.get(f\"{url}/index\")\n",
        "index.raise_for_status()"
      ],
      "metadata": {
        "id": "Ns6BKNQQ13FA"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Run queries\n",
        "\n",
        "Now that we have a knowledge source indexed, let's run a set of queries. The code below defines a function that calls the `/rag` endpoint and returns the response. Keep in mind this dataset is from 2004.\n",
        "\n",
        "While the Python Requests library is used in this notebook, this is a simple web endpoint that can be called from any programming language."
      ],
      "metadata": {
        "id": "_wGvCWsP17it"
      }
    },
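    {
      "cell_type": "markdown",
      "source": [
        "Because `/rag` is a plain `GET` endpoint, any client only has to URL-encode the question into the `text` query parameter. The block below is an illustrative sketch that uses the Python standard library to show what such a request URL looks like. The `build_rag_url` helper is a hypothetical name, not part of txtai.\n",
        "\n",
        "```python\n",
        "from urllib.parse import urlencode\n",
        "\n",
        "\n",
        "def build_rag_url(base, text):\n",
        "    # Hypothetical helper: encode the question into the query string\n",
        "    query = urlencode({'text': text})\n",
        "    return f'{base}/rag?{query}'\n",
        "\n",
        "\n",
        "print(build_rag_url('http://localhost:8000', 'Who won the World Series?'))\n",
        "# prints http://localhost:8000/rag?text=Who+won+the+World+Series%3F\n",
        "```\n",
        "\n",
        "Encoding matters for questions that contain characters such as `&` or `#`, which would otherwise be read as part of the URL structure."
      ],
      "metadata": {}
    },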
    {
      "cell_type": "code",
      "source": [
        "def rag(text):\n",
        "    return requests.get(f\"{url}/rag?text={text}\").json()[\"response\"]\n",
        "\n",
        "rag(\"Who is the current President?\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 39
        },
        "id": "_WbFu64L15Ch",
        "outputId": "3d631fe8-d1d3-4437-bf64-9248599caff9"
      },
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'George W. Bush'"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 14
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "rag(\"Who lost the presidential election?\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 39
        },
        "id": "YtJ7LJ_819vw",
        "outputId": "e102b060-edb3-483c-98f9-50892e5e6c70"
      },
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'John Kerry'"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 15
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "rag(\"Who won the World Series?\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 39
        },
        "id": "BlYDMTj41_QL",
        "outputId": "4f58fb40-2e75-4248-8065-5efc969fdd0e"
      },
      "execution_count": 16,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'Boston'"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 16
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "rag(\"Who did the Red Sox beat to win the world series?\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 39
        },
        "id": "XMHLmQ532ApE",
        "outputId": "4bf5c7fa-dd42-43e3-b473-2df9d2c64d29"
      },
      "execution_count": 17,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'Cardinals'"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 17
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "rag(\"What major hurricane hit the USA?\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 39
        },
        "id": "pTMqQSx82B_h",
        "outputId": "9ee72bc9-664b-407d-ac65-95f1a09a2cb2"
      },
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'Charley'"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 18
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "rag(\"What mobile phone manufacturer has the largest current marketshare?\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 39
        },
        "id": "BV99h7272DVj",
        "outputId": "20602f12-09fe-4a44-a3f4-1797885e9d22"
      },
      "execution_count": 19,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'Nokia'"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 19
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Wrapping up\n",
        "\n",
        "This notebook showed how a txtai application can be extended with custom endpoints in Python. Applications have a robust workflow framework, but some complex logic is easier to write in Python, and API extensions make that possible."
      ],
      "metadata": {
        "id": "oPwgCgBc2Er2"
      }
    }
  ]
}